from whisper.model import Linear, AudioEncoder, TextDecoder
import torch
def main():
    """Smoke-test the encoder and decoder by printing their output shapes.

    Builds a small ``AudioEncoder`` and ``TextDecoder``, feeds each one a
    random batch, and prints the shape of what comes back. Nothing is
    returned; the printed shapes are the only result.
    """
    # Shared dimensions, named once so the model configs and the random
    # tensors below cannot drift apart.
    batch_size = 10
    n_mels = 80
    n_audio_ctx = 512
    n_text_ctx = 120
    n_state = 128
    n_head = 8
    n_layer = 2
    n_vocab = 20000

    # Only shapes are inspected, so there is no need to record autograd state.
    with torch.no_grad():
        encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=n_audio_ctx,
            n_state=n_state,
            n_head=n_head,
            n_layer=n_layer,
        )
        # NOTE(review): the time axis is given n_audio_ctx frames. If this
        # AudioEncoder downsamples in time (the upstream OpenAI Whisper
        # encoder halves it with a stride-2 conv), the input would need
        # 2 * n_audio_ctx frames instead - confirm against whisper.model.
        mel = torch.randn(batch_size, n_mels, n_audio_ctx)
        print(encoder(mel).shape)

        decoder = TextDecoder(
            n_vocab=n_vocab,
            n_ctx=n_text_ctx,
            n_state=n_state,
            n_head=n_head,
            n_layer=n_layer,
        )
        # randint's upper bound is exclusive, so every id is < n_vocab.
        text_idx_list = torch.randint(0, n_vocab, (batch_size, n_text_ctx))
        # Random stand-in passed as the decoder's second argument; it is
        # deliberately independent of the encoder run above.
        xa = torch.randn(batch_size, n_audio_ctx, n_state)
        print(decoder(text_idx_list, xa).shape)


if __name__ == "__main__":
    main()
